{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Seq2Seq-Torch",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "vaOijL5peTmi",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# code by Tae Hwan Jung(Jeff Jung) @graykode, modify by wmathor\n",
        "import torch\n",
        "import numpy as np\n",
        "import torch.nn as nn\n",
        "import torch.utils.data as Data\n",
        "\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "# S: Symbol that shows starting of decoding input\n",
        "# E: Symbol that marks the end of a sequence (appended to the encoder input and to the decoder output)\n",
        "# ?: Padding symbol that fills a word up to n_step characters when it is shorter than n_step"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5QaLakJpeoCX",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "letter = [c for c in 'SE?abcdefghijklmnopqrstuvwxyz']\n",
        "letter2idx = {n: i for i, n in enumerate(letter)}\n",
        "\n",
        "seq_data = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low']]\n",
        "\n",
        "# Seq2Seq Parameter\n",
        "n_step = max([max(len(i), len(j)) for i, j in seq_data]) # max_len(=5)\n",
        "n_hidden = 128\n",
        "n_class = len(letter2idx) # classfication problem\n",
        "batch_size = 3"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tVdF8M6kfPNW",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def make_data(seq_data):\n",
        "    '''Encode [source, target] word pairs as tensors for the Seq2Seq model.\n",
        "\n",
        "    Each word is right-padded with '?' up to n_step characters. 'E' is appended to the\n",
        "    encoder input and to the decoder target, and 'S' is prepended to the decoder input.\n",
        "    Padding is applied to local copies, so the caller's seq_data is not modified.\n",
        "\n",
        "    Args:\n",
        "        seq_data: sequence of [source_word, target_word] pairs made of symbols in letter2idx.\n",
        "\n",
        "    Returns:\n",
        "        A tuple (enc_input_all, dec_input_all, dec_output_all). With no word longer than\n",
        "        n_step the shapes are:\n",
        "        enc_input_all:  float tensor [len(seq_data), n_step+1, n_class], one-hot\n",
        "        dec_input_all:  float tensor [len(seq_data), n_step+1, n_class], one-hot\n",
        "        dec_output_all: long tensor  [len(seq_data), n_step+1], class indices (not one-hot)\n",
        "\n",
        "    Raises:\n",
        "        KeyError: if a word contains a character that is not a key of letter2idx.\n",
        "    '''\n",
        "    one_hot = np.eye(n_class) # row i is the one-hot vector of class i; built once and reused\n",
        "    enc_input_all, dec_input_all, dec_output_all = [], [], []\n",
        "\n",
        "    for seq in seq_data:\n",
        "        # pad into new strings instead of assigning back into seq, so the input list is left untouched\n",
        "        source = seq[0] + '?' * (n_step - len(seq[0])) # 'man??'\n",
        "        target = seq[1] + '?' * (n_step - len(seq[1])) # 'women'\n",
        "\n",
        "        enc_input = [letter2idx[n] for n in (source + 'E')] # ['m', 'a', 'n', '?', '?', 'E']\n",
        "        dec_input = [letter2idx[n] for n in ('S' + target)] # ['S', 'w', 'o', 'm', 'e', 'n']\n",
        "        dec_output = [letter2idx[n] for n in (target + 'E')] # ['w', 'o', 'm', 'e', 'n', 'E']\n",
        "\n",
        "        enc_input_all.append(one_hot[enc_input])\n",
        "        dec_input_all.append(one_hot[dec_input])\n",
        "        dec_output_all.append(dec_output) # not one-hot\n",
        "\n",
        "    # Stack into one ndarray first: building a tensor straight from a list of ndarrays is slow\n",
        "    # and triggers a UserWarning in recent PyTorch releases.\n",
        "    return (torch.tensor(np.array(enc_input_all), dtype=torch.float32),\n",
        "            torch.tensor(np.array(dec_input_all), dtype=torch.float32),\n",
        "            torch.tensor(dec_output_all, dtype=torch.long))\n",
        "\n",
        "# Shapes of the full training set built below:\n",
        "#   enc_input_all:  [6, n_step+1 (because of 'E'), n_class]\n",
        "#   dec_input_all:  [6, n_step+1 (because of 'S'), n_class]\n",
        "#   dec_output_all: [6, n_step+1 (because of 'E')]\n",
        "enc_input_all, dec_input_all, dec_output_all = make_data(seq_data)"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9TVfD38FmJzZ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class TranslateDataSet(Data.Dataset):\n",
        "    def __init__(self, enc_input_all, dec_input_all, dec_output_all):\n",
        "        self.enc_input_all = enc_input_all\n",
        "        self.dec_input_all = dec_input_all\n",
        "        self.dec_output_all = dec_output_all\n",
        "    \n",
        "    def __len__(self): # return dataset size\n",
        "        return len(self.enc_input_all)\n",
        "    \n",
        "    def __getitem__(self, idx):\n",
        "        return self.enc_input_all[idx], self.dec_input_all[idx], self.dec_output_all[idx]\n",
        "\n",
        "# Mini-batches of batch_size samples, reshuffled every epoch\n",
        "loader = Data.DataLoader(\n",
        "    dataset=TranslateDataSet(enc_input_all, dec_input_all, dec_output_all),\n",
        "    batch_size=batch_size,\n",
        "    shuffle=True,\n",
        ")"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5_Xm1E_kf5FQ",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 70
        },
        "outputId": "e702e0a3-5222-4e5b-a37f-779cafd0d32c"
      },
      "source": [
        "# Model\n",
        "class Seq2Seq(nn.Module):\n",
        "    '''Character-level encoder-decoder built from two single-layer RNNs.\n",
        "\n",
        "    The encoder's final hidden state initialises the decoder, and a linear layer maps\n",
        "    every decoder output step to n_class scores.\n",
        "    '''\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        # No dropout argument: nn.RNN applies dropout only between stacked layers, so with the\n",
        "        # default num_layers=1 a non-zero value has no effect other than raising a UserWarning.\n",
        "        self.encoder = nn.RNN(input_size=n_class, hidden_size=n_hidden) # encoder\n",
        "        self.decoder = nn.RNN(input_size=n_class, hidden_size=n_hidden) # decoder\n",
        "        self.fc = nn.Linear(n_hidden, n_class)\n",
        "\n",
        "    def forward(self, enc_input, enc_hidden, dec_input):\n",
        "        '''Compute class scores for every decoding step.\n",
        "\n",
        "        Args:\n",
        "            enc_input:  [batch_size, n_step+1, n_class] encoder input.\n",
        "            enc_hidden: [num_layers(=1) * num_directions(=1), batch_size, n_hidden] initial encoder state.\n",
        "            dec_input:  [batch_size, n_step+1, n_class] decoder input.\n",
        "\n",
        "        Returns:\n",
        "            Tensor [n_step+1, batch_size, n_class] of unnormalised scores (time-major).\n",
        "        '''\n",
        "        # nn.RNN is time-major by default, so move the step axis to the front\n",
        "        enc_input = enc_input.transpose(0, 1) # enc_input: [n_step+1, batch_size, n_class]\n",
        "        dec_input = dec_input.transpose(0, 1) # dec_input: [n_step+1, batch_size, n_class]\n",
        "\n",
        "        # h_t : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]\n",
        "        _, h_t = self.encoder(enc_input, enc_hidden)\n",
        "        # outputs : [n_step+1, batch_size, num_directions(=1) * n_hidden(=128)]\n",
        "        outputs, _ = self.decoder(dec_input, h_t)\n",
        "\n",
        "        # called logits rather than model so the module-level model instance is not shadowed\n",
        "        logits = self.fc(outputs) # logits : [n_step+1, batch_size, n_class]\n",
        "        return logits\n",
        "\n",
        "# Network, loss and optimiser; the modules are moved to the selected device\n",
        "model = Seq2Seq()\n",
        "model = model.to(device)\n",
        "criterion = nn.CrossEntropyLoss()\n",
        "criterion = criterion.to(device)\n",
        "optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py:50: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n",
            "  \"num_layers={}\".format(dropout, num_layers))\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "u_JW2thNezfH",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 185
        },
        "outputId": "7d589e7b-f35a-48a1-c125-1584d646f1e3"
      },
      "source": [
        "# Train with teacher forcing: the decoder is fed the ground-truth target shifted right by 'S'\n",
        "for epoch in range(5000):\n",
        "    for enc_input_batch, dec_input_batch, dec_output_batch in loader:\n",
        "        enc_input_batch = enc_input_batch.to(device)   # [batch, n_step+1, n_class]\n",
        "        dec_input_batch = dec_input_batch.to(device)   # [batch, n_step+1, n_class]\n",
        "        dec_output_batch = dec_output_batch.to(device) # [batch, n_step+1], not one-hot\n",
        "\n",
        "        # hidden shape [num_layers * num_directions, batch, n_hidden]; sized from the batch actually\n",
        "        # received so a smaller final batch does not clash with the global batch_size\n",
        "        h_0 = torch.zeros(1, enc_input_batch.size(0), n_hidden, device=device)\n",
        "\n",
        "        pred = model(enc_input_batch, h_0, dec_input_batch)\n",
        "        # pred : [n_step+1, batch, n_class]\n",
        "        pred = pred.transpose(0, 1) # [batch, n_step+1(=6), n_class]\n",
        "\n",
        "        # sum over the batch of each sample's mean cross-entropy\n",
        "        # sample_pred : [n_step+1, n_class], sample_target : [n_step+1]\n",
        "        loss = sum(criterion(sample_pred, sample_target)\n",
        "                   for sample_pred, sample_target in zip(pred, dec_output_batch))\n",
        "\n",
        "        if (epoch + 1) % 1000 == 0:\n",
        "            print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))\n",
        "\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch: 1000 cost = 0.002763\n",
            "Epoch: 1000 cost = 0.002548\n",
            "Epoch: 2000 cost = 0.000556\n",
            "Epoch: 2000 cost = 0.000524\n",
            "Epoch: 3000 cost = 0.000165\n",
            "Epoch: 3000 cost = 0.000163\n",
            "Epoch: 4000 cost = 0.000057\n",
            "Epoch: 4000 cost = 0.000054\n",
            "Epoch: 5000 cost = 0.000017\n",
            "Epoch: 5000 cost = 0.000020\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qS-Uwpwte1AO",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 118
        },
        "outputId": "a0c61c01-a56a-49ea-c427-20c28f73f180"
      },
      "source": [
        "# Test\n",
        "def translate(word):\n",
        "    '''Translate word with the trained model and return the predicted string.\n",
        "\n",
        "    The decoder input is 'S' followed by n_step padding symbols. The prediction is cut at\n",
        "    the first 'E' (kept whole when no 'E' is predicted) and padding symbols are removed.\n",
        "\n",
        "    Args:\n",
        "        word: source string made of symbols in letter2idx.\n",
        "\n",
        "    Returns:\n",
        "        The decoded string without 'E' and '?' symbols.\n",
        "    '''\n",
        "    enc_input, dec_input, _ = make_data([[word, '?' * n_step]])\n",
        "    enc_input, dec_input = enc_input.to(device), dec_input.to(device)\n",
        "    # make hidden shape [num_layers * num_directions, batch_size(=1), n_hidden]\n",
        "    hidden = torch.zeros(1, 1, n_hidden, device=device)\n",
        "    with torch.no_grad(): # inference only, no autograd graph needed\n",
        "        output = model(enc_input, hidden, dec_input)\n",
        "    # output : [n_step+1, batch_size(=1), n_class]\n",
        "\n",
        "    predict = output.argmax(dim=2).view(-1).tolist() # best class index per step, as plain ints\n",
        "    decoded = ''.join(letter[i] for i in predict)\n",
        "    # partition keeps the whole string when 'E' is absent, where list.index('E') would raise ValueError\n",
        "    translated = decoded.partition('E')[0]\n",
        "\n",
        "    return translated.replace('?', '')\n",
        "\n",
        "print('test')\n",
        "# 'mans' is not in the training pairs; the others are\n",
        "for word in ('man', 'mans', 'king', 'black', 'up'):\n",
        "    print('%s ->' % word, translate(word))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "test\n",
            "man -> women\n",
            "mans -> women\n",
            "king -> queen\n",
            "black -> white\n",
            "up -> down\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AE5gntwU-V-M",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}